[MXNET-876] make CachedOp a normal operator #11641
Conversation
include/mxnet/c_api.h
Outdated
@@ -1073,6 +1073,9 @@ MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
                                  int *input_size);

int MXMakeSubgraph(SymbolHandle sym, SymbolHandle *input_symbols, mx_uint num_inputs,
Doc missing
python/mxnet/symbol/contrib.py
Outdated
@@ -336,3 +336,10 @@ def check_data(inputs, in_type, msg):
        states = states[0]

    return (outs, states)

def make_subgraph(subg, *args):
Doc missing
src/c_api/c_api_symbolic.cc
Outdated
  // Construct a node for this subgraph.
  std::vector<nnvm::NodeEntry> inputs(num_inputs);
  for (size_t i = 0; i < inputs.size(); i++) {
    nnvm::Symbol *s = static_cast<nnvm::Symbol*>(input_symbols[i]);
Can it be const?
src/c_api/c_api_symbolic.cc
Outdated
  // Create CachedOp for the node.
  std::vector<std::pair<std::string, std::string> > kwargs;
  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
Why not emplace_back? More efficient and less noisy...
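A minimal, standalone sketch of the suggestion (not the PR's actual code): emplace_back forwards its arguments and constructs the pair directly in the vector's storage, instead of building a temporary pair and moving it in.

```cpp
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, std::string> > kwargs;
  // push_back: a temporary pair is constructed first, then moved into the vector.
  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
  // emplace_back: the pair is constructed in place from the two string literals.
  kwargs.emplace_back("inline_limit", "0");
  return 0;
}
```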
src/c_api/c_api_symbolic.cc
Outdated
  n->attrs.parsed = std::make_shared<mxnet::CachedOp>(*s, kwargs);

  // Create a new symbol for this node.
  s = new nnvm::Symbol();
Can this leak? Who manages this one?
I just followed the implementations in other APIs. The symbol will be stored in a Python symbol handle. AFAIK, once the Python symbol handle is destroyed, the symbol object is destroyed as well.
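A self-contained sketch of the ownership pattern being described; Symbol, SymbolHandle, MakeSymbol, and FreeSymbol are hypothetical stand-ins used only for illustration, with the real counterparts being nnvm::Symbol, SymbolHandle, MXMakeSubgraph, and MXSymbolFree.

```cpp
#include <iostream>

// Hypothetical stand-ins to illustrate the C-API handle ownership pattern.
struct Symbol { int dummy = 0; };
typedef void *SymbolHandle;

// The C API allocates the object and transfers ownership to the caller
// through the opaque handle.
int MakeSymbol(SymbolHandle *out) {
  *out = new Symbol();
  return 0;
}

// The frontend releases the object when the Python Symbol wrapper is
// destroyed (MXNet does this via MXSymbolFree).
int FreeSymbol(SymbolHandle handle) {
  delete static_cast<Symbol *>(handle);
  return 0;
}

int main() {
  SymbolHandle h = nullptr;
  MakeSymbol(&h);  // ownership now lives with the caller
  FreeSymbol(h);   // no leak as long as the frontend always frees the handle
  std::cout << "handle created and freed" << std::endl;
  return 0;
}
```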
src/imperative/cached_op.cc
Outdated
  std::shared_ptr<CachedOp> op;
  OpStatePtr forward_state;

  CachedOpActualState(std::shared_ptr<CachedOp> op) {
By reference?
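A minimal sketch of the suggested change, with a simplified stand-in for CachedOp: taking the shared_ptr by const reference avoids bumping the reference count for the parameter itself; only the member initialization touches it.

```cpp
#include <memory>

struct CachedOp {};  // simplified stand-in for mxnet::CachedOp

struct CachedOpActualState {
  std::shared_ptr<CachedOp> op;

  // By const reference, as suggested: no extra refcount increment for the
  // argument; the count is only bumped when the member is copy-initialized.
  explicit CachedOpActualState(const std::shared_ptr<CachedOp> &op_) : op(op_) {}
};

int main() {
  auto op = std::make_shared<CachedOp>();
  CachedOpActualState state(op);
  return 0;
}
```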
src/imperative/cached_op.cc
Outdated
@@ -1047,6 +1047,105 @@ void CachedOp::Backward(
  Engine::Get()->set_bulk_size(prev_bulk_size);
}

struct CachedOpActualState {
Missing class doc
src/imperative/cached_op.cc
Outdated
  }
};

void CachedOpForward(const OpStatePtr& state_ptr,
Missing short documentation stating the intention and how it works.
src/imperative/cached_op.cc
Outdated
  const std::vector<bool> &save_outputs = s.op->save_outputs();
  CHECK_EQ(save_inputs.size(), in_end - in_begin);
  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
  for (auto it = in_begin; it != in_end; it++) {
++it is potentially faster
Really? Where is this documented?
@zheng-da it's a well-known thing for old C++ farts. It's covered in reference C++ books like http://www.cppstdlib.com/ or Stroustrup, and in https://stackoverflow.com/questions/1077026/incrementing-iterators-it-more-efficient-than-it
In most cases it probably doesn't make a difference, especially for simple iterators where the iterator is just a pointer. That's why I said it's potentially faster. It's more of a good idiomatic practice to always use preincrement.
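A small, standalone sketch of the idiom being discussed: post-increment has to return a copy of the iterator's old value even when that copy is discarded, while pre-increment does not, so preincrement can never be slower.

```cpp
#include <iostream>
#include <list>

int main() {
  std::list<int> values = {1, 2, 3};

  // Post-increment: the iterator keeps a copy of its old value so it can be
  // returned, even though the copy is unused here.
  for (auto it = values.begin(); it != values.end(); it++)
    std::cout << *it << ' ';
  std::cout << '\n';

  // Pre-increment: no temporary copy is needed. For pointer-like iterators
  // the compiler usually generates identical code either way, which is why
  // this is mainly an idiomatic-practice point.
  for (auto it = values.begin(); it != values.end(); ++it)
    std::cout << *it << ' ';
  std::cout << '\n';
  return 0;
}
```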
src/imperative/cached_op.h
Outdated
@@ -116,6 +116,24 @@ class CachedOp {
      DispatchMode* dispatch_mode,
      std::vector<int> *in_attrs,
      std::vector<int> *out_attrs);
  bool ForwardInferShape(
Missing documentation in function prototypes
                         'b': mx.nd.empty(shape=(10, 10))})
    e1.forward()
    e2.forward()
    assert_almost_equal(e1.outputs[0].asnumpy(), e2.outputs[0].asnumpy(),
Why almost equal and not equal?
I think it's due to floating point precision.
Added comments.
@larroy thanks for the review. The code, especially the API design, is still experimental. I'll let you know when the code is ready for review.
I think functions like shape and type inference, FMutate, etc. used in operator registration should not belong to CachedOp only. They should be made generally available to subgraph-type operators, with CachedOp being just a special case.
This PR should be rebased onto #12157.
src/operator/subgraph/common.h
Outdated
inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
                                   std::vector<TShape> *in_shapes,
                                   std::vector<TShape> *out_shapes) {
  return DefaultSubgraphOpShape1(*attrs.subgraphs[0], in_shapes, out_shapes);
Maybe rename DefaultSubgraphOpShape1 to something like a helper-function name for better readability?
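One possible shape of that rename, sketched against the overload shown above; the name DefaultSubgraphOpShapeHelper is hypothetical and not necessarily what the PR ended up using.

```cpp
// Hypothetical rename sketch: the symbol-level overload becomes an explicitly
// named helper, and the NodeAttrs overload simply forwards to it.
inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol &subgraph,
                                         std::vector<TShape> *in_shapes,
                                         std::vector<TShape> *out_shapes) {
  // ... infer shapes on the subgraph, as DefaultSubgraphOpShape1 does today ...
  return true;
}

inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs &attrs,
                                   std::vector<TShape> *in_shapes,
                                   std::vector<TShape> *out_shapes) {
  return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
}
```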
const auto& idx = g.indexed_graph();
const auto &outputs = idx.outputs();
/*
 * This is the operator state of CachedOp when CachedOp is used in the symbol
Please elaborate on the necessity of adding this data structure in the description.
  // Clean up what we recorded.
  s.forward_state.reset();

  // The arrays in out_ptrs may be changed by CachedOp.
Why would it be changed?
Thanks for updating the comments
src/imperative/cached_op.cc
Outdated
  else
    orig_is_train = Imperative::Get()->is_training();
  // TODO(zhengda) is it right to use false here?
  s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
Please add more comments on retain_graph=False.
Description
Currently, CachedOp is used to execute the graph of a Gluon hybrid block when the block is hybridized. It is registered as an operator, but it doesn't have a full set of operator attributes, so it can't be used as a regular operator or placed in a normal NNVM computation graph. This PR extends CachedOp to make it a normal operator. The main motivation is to use it as the default subgraph operator, as proposed in Unified integration with external acceleration libraries.
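As a rough illustration of what a "full set of operator attributes" means here, a hedged sketch of how a stateful graph-level operator is typically registered with NNVM; DefaultSubgraphOpShape and CachedOpForward appear in this PR's diffs, while DefaultSubgraphOpType and CreateCachedOpState are placeholder names, not necessarily what the PR registers.

```cpp
// Sketch only: the registration hooks an operator needs to participate in a
// normal NNVM graph (shape/type inference, a state constructor, and a
// stateful compute function).
NNVM_REGISTER_OP(_CachedOp)
.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)         // from subgraph/common.h
.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)            // placeholder counterpart
.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)            // placeholder state constructor
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward);  // from cached_op.cc
```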
Checklist
Essentials
Changes
Comments